CREATE OR REPLACE FUNCTION "pipeline"."fp_growth"("source_table" text, "out_table" varchar, "feature_cols" varchar, "min_frequency" int4, "min_support" float8)
  RETURNS "pg_catalog"."text" AS $BODY$
  import numpy as np
  import json
  import time
  import math
  from collections import defaultdict, namedtuple

  # Initialise the response envelope. begin_alg serialises this dict with
  # json.dumps on every exit path; throwError overwrites status/error_msg.
  _input_params = {
    "out_table": out_table,
    "source_table": source_table,
    "feature_cols": feature_cols,
    "min_frequency": min_frequency,
    "min_support": min_support
  }
  # out_table_name stays None until the output table has been written.
  _output_params = [
    {
      "output_cols": ["_record_id_", "pattern", "frequency", "support"],
      "out_table_name": None
    }
  ]
  result = {
    "status": 0,
    "error_msg": "success",
    "result": {
      "input_params": _input_params,
      "output_params": _output_params
    }
  }

  def find_frequent_itemsets(transactions, minimum_support, include_support=True):
    """Yield every itemset that occurs in at least `minimum_support` transactions.

    transactions    -- iterable of item sequences; a None entry is treated as
                       an empty transaction.
    minimum_support -- absolute occurrence count an itemset must reach.
    include_support -- when true, yield (itemset, support) tuples; otherwise
                       yield the itemset list alone.

    This is a generator: nothing is computed until it is iterated.
    """
    # The data is scanned twice (counting, then tree building), so materialise
    # it once; this also makes one-shot iterables safe to pass in.
    transactions = [t if t is not None else () for t in transactions]

    # Pass 1: count how often each item occurs.
    counts = defaultdict(int)
    for transaction in transactions:
        for item in transaction:
            counts[item] += 1

    # Keep only the items that are frequent on their own.
    items = dict((item, support) for item, support in counts.items()
        if support >= minimum_support)

    # Fix ONE global order over the frequent items (most frequent first).
    # Sorting each transaction by support alone leaves items with equal support
    # in per-transaction order, so [a, b] and [b, a] would land on different
    # branches and the support of {a, b} would be split between them.
    ranked = sorted(items, key=lambda v: items[v], reverse=True)
    rank = dict((item, position) for position, item in enumerate(ranked))

    def clean_transaction(transaction):
        # Drop infrequent items and put the rest into the global order.
        kept = [v for v in transaction if v in rank]
        kept.sort(key=lambda v: rank[v])
        return kept

    # Pass 2: build the FP tree.
    master = FPTree()
    for transaction in transactions:
        master.add(clean_transaction(transaction))

    def find_with_suffix(tree, suffix):
        # Recursively grow patterns: every frequent item of `tree` extends
        # `suffix`, then its conditional tree is mined the same way.
        for item, nodes in tree.items():
            support = sum(n.count for n in nodes)
            if support >= minimum_support and item not in suffix:

                found_set = [item] + suffix
                yield (found_set, support) if include_support else found_set

                cond_tree = conditional_tree_from_paths(tree.prefix_paths(item))
                for s in find_with_suffix(cond_tree, found_set):
                    yield s

    for itemset in find_with_suffix(master, []):
        yield itemset

  class FPTree(object):
    """Prefix tree of transactions plus, per item, a linked chain of its nodes.

    `_routes` maps each item to a Route(head, tail): the first and the most
    recently added node carrying that item. Nodes in between are reachable by
    following `neighbor` links from the head.
    """

    Route = namedtuple('Route', 'head tail')

    def __init__(self):
        # The root is the only node whose item and count are both None.
        self._root = FPNode(self, None, None)
        self._routes = {}

    @property
    def root(self):
        """The root node of the tree."""
        return self._root

    def add(self, transaction):
        """Insert one transaction, walking/creating one node per item."""
        cursor = self._root
        for item in transaction:
            child = cursor.search(item)
            if child:
                # Path already exists: just count one more occurrence.
                child.increment()
            else:
                child = FPNode(self, item)
                cursor.add(child)
                self._update_route(child)
            cursor = child

    def _update_route(self, point):
        """Append `point` to the node chain kept for its item."""
        assert self is point.tree

        if point.item in self._routes:
            head, tail = self._routes[point.item]
            tail.neighbor = point
            self._routes[point.item] = self.Route(head, point)
        else:
            # First node for this item: it is both head and tail.
            self._routes[point.item] = self.Route(point, point)

    def items(self):
        """Yield (item, node-generator) for every item present in the tree."""
        for item in self._routes.keys():
            yield (item, self.nodes(item))

    def nodes(self, item):
        """Yield every node carrying `item`; nothing if the item is unknown."""
        route = self._routes.get(item)
        if route is None:
            return

        node = route.head
        while node:
            yield node
            node = node.neighbor

    def prefix_paths(self, item):
        """Return a generator of root-to-node paths, one per node of `item`.

        Each path is a list of nodes that excludes the root and ends with the
        node carrying `item`.
        """
        def path_to(node):
            # Climb to the root, then flip so the path reads top-down.
            climbed = []
            while node and not node.root:
                climbed.append(node)
                node = node.parent
            return climbed[::-1]

        return (path_to(node) for node in self.nodes(item))

    def inspect(self):
        """Debug helper: walk the tree, then print every node per item."""
        self.root.inspect(1)
        for item, nodes in self.items():
            for node in nodes:
                print('    %r' % node)
  
  def conditional_tree_from_paths(paths):
    """Build the conditional FP tree for the item that ends every path.

    `paths` is an iterable of node lists (as produced by FPTree.prefix_paths).
    Nodes of the conditioning item keep their original count; every other node
    starts at 0 and is then credited with the counts of the conditioning nodes
    below it.
    """
    cond_tree = FPTree()
    target = None

    for path in paths:
        # The conditioning item is whatever terminates the first path seen.
        if target is None:
            target = path[-1].item

        cursor = cond_tree.root
        for source_node in path:
            child = cursor.search(source_node.item)
            if not child:
                if source_node.item == target:
                    seed = source_node.count
                else:
                    seed = 0
                child = FPNode(cond_tree, source_node.item, seed)
                cursor.add(child)
                cond_tree._update_route(child)
            cursor = child

    # An empty `paths` leaves nothing to condition on.
    assert target is not None

    # Propagate each conditioning node's count up to all of its ancestors.
    for path in cond_tree.prefix_paths(target):
        leaf_count = path[-1].count
        for ancestor in path[:-1][::-1]:
            ancestor._count += leaf_count
    return cond_tree

  class FPNode(object):
    """One node of an FP tree: an item, its count, and its tree links.

    A node whose item and count are both None is a root node.
    """

    def __init__(self, tree, item, count=1):
        self._tree = tree
        self._item = item
        self._count = count
        self._parent = None
        self._children = {}   # item -> child FPNode
        self._neighbor = None  # next node in the tree carrying the same item

    def add(self, child):
        """Attach `child` below this node unless its item is already present.

        Raises TypeError if `child` is not an FPNode.
        """
        if not isinstance(child, FPNode):
            raise TypeError("Can only add other FPNodes as children")

        if child.item not in self._children:
            self._children[child.item] = child
            child.parent = self

    def search(self, item):
        """Return the child carrying `item`, or None if there is none."""
        try:
            return self._children[item]
        except KeyError:
            return None

    def __contains__(self, item):
        return item in self._children

    @property
    def tree(self):
        """The tree this node belongs to."""
        return self._tree

    @property
    def item(self):
        """The item stored in this node (None for a root)."""
        return self._item

    @property
    def count(self):
        """The occurrence count associated with this node's item."""
        return self._count

    def increment(self):
        """Add one to the count; raises ValueError when the count is None."""
        if self._count is None:
            raise ValueError("Root nodes have no associated count.")
        self._count += 1

    @property
    def root(self):
        """True when this node is a root (no item and no count)."""
        return self._item is None and self._count is None

    @property
    def leaf(self):
        """True when this node has no children."""
        return len(self._children) == 0

    @property
    def parent(self):
        """The parent node, or None."""
        return self._parent

    @parent.setter
    def parent(self, value):
        if value is not None and not isinstance(value, FPNode):
            raise TypeError("A node must have an FPNode as a parent.")
        if value and value.tree is not self.tree:
            raise ValueError("Cannot have a parent from another tree.")
        self._parent = value

    @property
    def neighbor(self):
        """The next node in the same tree with the same item, or None."""
        return self._neighbor

    @neighbor.setter
    def neighbor(self, value):
        if value is not None and not isinstance(value, FPNode):
            raise TypeError("A node must have an FPNode as a neighbor.")
        if value and value.tree is not self.tree:
            raise ValueError("Cannot have a neighbor from another tree.")
        self._neighbor = value

    @property
    def children(self):
        """The child nodes as a tuple."""
        # values() works on Python 2 and 3; itervalues() exists only on
        # Python 2 and raised AttributeError everywhere else.
        return tuple(self._children.values())

    def inspect(self, depth=0):
        """Debug helper: recursively visit every descendant."""
        for child in self.children:
            child.inspect(depth + 1)

    def __repr__(self):
        if self.root:
            return "<%s (root)>" % type(self).__name__
        return "<%s %r (%r)>" % (type(self).__name__, self.item, self.count)


  # Read the data
  def get_data_sql(source_table, feature_col):
    """Return the SELECT statement that reads `feature_col` from `source_table`.

    The column name is wrapped in double quotes; the table name is inserted
    verbatim.
    """
    return 'select "%s" from %s' % (feature_col, source_table)

  def getTableMetaInfo(table_name):
    """Return {column_name: data_type} for `table_name` from information_schema.

    `table_name` may be bare ("tbl") or qualified ("schema.tbl"); only the last
    dot-separated component is used for the lookup. The result is an empty dict
    when no column matches.
    """
    # The old `len(tmps) > 0` guard was always true, so an unqualified name
    # crashed with IndexError on tmps[1]. Taking the last component handles
    # both forms; surrounding double quotes are not part of the catalog name.
    relation = table_name.split(".")[-1].strip('"')
    # Bind the name as a parameter instead of formatting it into the SQL text.
    # NOTE(review): the lookup is not restricted to a schema, so same-named
    # tables in different schemas contribute columns to one dict.
    plan = plpy.prepare(
      "select column_name, data_type from information_schema.columns where table_name = $1",
      ["text"])
    rs = plpy.execute(plan, [relation])
    meta = {}
    for line in rs:
      meta[line['column_name']] = line['data_type']
    return meta


  def validate(featureCols, meta):
    """Check that every feature column exists and that all share one kind.

    featureCols -- column names; single and double quotes are stripped before
                   the lookup.
    meta        -- {column_name: data_type} mapping.

    Returns (ok, isNumeric, message):
      * (False, False, "key=<col> not exists") for the first unknown column;
      * (True, False, "success") when no column is numeric;
      * (True, True, "success") when every column is numeric;
      * (False, False, "features type is not same") for a mix.
    """
    numericTypes = ("smallint", "integer", "decimal", "numeric", "real",
                    "double precision", "serial", "bigserial", "bigint")
    numericCount = 0
    for feature in featureCols:
      bareName = feature.replace('"', "").replace("'", "")
      columnType = meta.get(bareName)
      if columnType is None:
        return False, False, "key={} not exists".format(feature)
      if columnType.lower() in numericTypes:
        numericCount += 1

    if numericCount == 0:
      return True, False, "success"
    if numericCount == len(featureCols):
      return True, True, "success"
    return False, False, "features type is not same"

  def createTmpTable(source_table, featureCols, isNumer):
    """Create a scratch view exposing the feature columns as one `event` column.

    The view selects "pipeline"."cols_to_vec_numeric"(...) when `isNumer` is
    true and "pipeline"."cols_to_vec_char"(...) otherwise, over `source_table`.
    Its name is derived from the current time.

    Returns (ok, viewName, message). On failure the message carries the error
    text followed by the statement, and a DROP VIEW IF EXISTS is issued for the
    view name.
    """
    viewSqlTemplate = """
      CREATE VIEW {} AS SELECT "pipeline"."{}"({}) AS event FROM {}
    """
    viewName = "pipeline.view_fpgrowth_{}".format(int(time.time() * 10000))
    packer = "cols_to_vec_numeric" if isNumer else "cols_to_vec_char"
    statement = viewSqlTemplate.format(viewName, packer, ",".join(featureCols), source_table)
    try:
      plpy.execute(statement)
    except Exception as e:
      failure = str(e) + " sql=" + statement
      plpy.execute("DROP VIEW IF EXISTS {}".format(viewName))
      return False, viewName, failure
    return True, viewName, "success"
        
  # Create the output table
  def saveOutTable(datas, out_table):
    """Create `out_table` and bulk-insert the mined patterns into it.

    datas     -- rows of (record id, pattern literal, frequency, support); the
                 pattern is placed between single quotes as-is, so the caller
                 must already have escaped it.
    out_table -- name of the table to create.

    Returns the table name. Errors from either statement propagate.
    """
    table_name = out_table
    create_sql = "CREATE TABLE %s (_record_id_ int, pattern text[], frequency int, support float);" % table_name
    plpy.execute(create_sql)

    # One "(id, 'pattern', freq, support)" tuple per row, sent in a single INSERT.
    row_literals = ["({}, '{}', {}, {})".format(row[0], row[1], row[2], row[3])
                    for row in datas]
    insert_sql = "INSERT INTO %s (_record_id_, pattern, frequency, support) VALUES" % table_name
    plpy.execute(insert_sql + ",".join(row_literals))
    return table_name
  
  def throwError(e):
    """Record a failure in the shared `result` envelope (status 500).

    Despite the name nothing is raised: `status` becomes 500 and `error_msg`
    becomes str(e).
    """
    global result
    result["status"] = 500
    message = str(e)
    result["error_msg"] = message
  
  # Main routine
  def begin_alg(source_table, out_table, feature_cols, min_frequency, min_support):
    """Run the whole FP-growth job and return the result envelope as JSON.

    source_table  -- a table/view name, or a complete "create view <name> ..."
                     statement that is executed first and whose view is then
                     used as the source.
    out_table     -- name of the table that receives the patterns.
    feature_cols  -- comma-separated column names.
    min_frequency -- minimum absolute occurrence count.
    min_support   -- minimum fraction of rows; converted to a count and used
                     when it is stricter than min_frequency.

    Every failure handled here is reported through the envelope
    (status 500 + error_msg) instead of raising.
    """
    global result
    # Accept "a, b" as well as "a,b": stray blanks used to make the column
    # lookup fail with "key= b not exists".
    featureCols = [col.strip() for col in feature_cols.split(",")]
    sourceTable = source_table
    if source_table.lstrip().lower().startswith("create view"):
        try:
            plpy.execute(source_table)
        except Exception as e:
            # A failed CREATE VIEW leaves nothing to clean up. The previous
            # handler issued DROP VIEW with the entire statement as the view
            # name, which raised again and hid this error from the caller.
            result["status"] = 500
            result["error_msg"] = "sq={}, errorMsg={}".format(source_table, str(e))
            return json.dumps(result)
        # Third whitespace-separated token of "create view <name> ...".
        sourceTable = source_table.split()[2]

    meta = getTableMetaInfo(sourceTable)
    flag, isNumer, msg = validate(featureCols, meta)
    if not flag:
      throwError(msg)
      return json.dumps(result)
    flag, tmpTableName, msg = createTmpTable(sourceTable, featureCols, isNumer)
    if not flag:
      throwError(msg)
      return json.dumps(result)

    feature_col = "event"
    saveRet = None
    # From here on the scratch view exists; the finally clause drops it on
    # every exit path (the JSON for an early return is built before it runs).
    try:
      try:
        rv = plpy.execute(get_data_sql(tmpTableName, feature_col))
      except Exception as e:
        throwError(e)
        return json.dumps(result)

      datas = [row[feature_col] for row in rv]

      # Translate the relative support into a row count and keep the stricter
      # of the two thresholds.
      sup_frequency = int(math.ceil(min_support * len(datas)))
      if min_frequency < sup_frequency:
        min_frequency = sup_frequency

      frequent_items = find_frequent_itemsets(datas, minimum_support = min_frequency, include_support=True)
      frequent_items = sorted(frequent_items, key=lambda i: i[1], reverse=True)

      res = []
      for index, item in enumerate(frequent_items, 1):
        # Turn the Python list repr into a "{...}" array literal and double
        # single quotes so it survives inside a quoted SQL string.
        # NOTE(review): string items keep their Python repr quotes inside the
        # stored array elements — confirm whether consumers rely on that.
        string = str(item[0]).replace('[', '{')
        string = string.replace(']', '}')
        string = string.replace('\'', '\'\'')
        res.append([index, string, item[1], float(item[1]) / float(len(datas))])

      if len(res) == 0:
        # Not an error: status stays 0 and no output table is created.
        result["error_msg"] = "No patters found."
        return json.dumps(result)

      try:
        saveRet = saveOutTable(res, out_table)
      except Exception as e:
        result["testMsg"] = res
        throwError(e)
        return json.dumps(result)
    finally:
      plpy.execute("DROP VIEW IF EXISTS {}".format(tmpTableName))

    result["result"]["output_params"][0]["out_table_name"] = saveRet

    return json.dumps(result)
  
  # Entry point: run the job with the SQL function's arguments and hand the
  # JSON string from begin_alg back to the caller.
  # NOTE(review): the bare `return` below is outside any Python `def`; it
  # presumably relies on PL/Python wrapping this body in a generated function —
  # confirm before running the script anywhere else.
  if __name__ == '__main__':
    # NOTE(review): `result` is assigned near the top of this body, i.e. before
    # this declaration in the same scope. Python 2 only warns about that, while
    # Python 3 rejects it as a SyntaxError. This declaration looks like what
    # makes `result` reachable from the helpers that declare `global result` —
    # verify before moving it or porting to plpython3u.
    global result
    ret = begin_alg(source_table, out_table, feature_cols, min_frequency, min_support)
    return ret
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100